package cn.edu.dgut.css.sai.security.oauth2.client.endpoint;

import cn.edu.dgut.css.sai.security.oauth2.core.http.converter.SaiCommonOAuth2AccessTokenResponseHttpMessageConverter;
import cn.edu.dgut.css.sai.security.oauth2.core.http.converter.SaiQQOAuth2AccessTokenResponseHttpMessageConverter;
import com.fasterxml.jackson.core.type.TypeReference;
import com.fasterxml.jackson.databind.ObjectMapper;
import com.nimbusds.oauth2.sdk.http.HTTPRequest;
import com.nimbusds.oauth2.sdk.http.HTTPResponse;
import org.apache.commons.lang3.StringUtils;
import org.springframework.http.MediaType;
import org.springframework.http.RequestEntity;
import org.springframework.http.converter.FormHttpMessageConverter;
import org.springframework.http.converter.HttpMessageConverter;
import org.springframework.security.authentication.AuthenticationServiceException;
import org.springframework.security.oauth2.client.endpoint.DefaultAuthorizationCodeTokenResponseClient;
import org.springframework.security.oauth2.client.endpoint.OAuth2AccessTokenResponseClient;
import org.springframework.security.oauth2.client.endpoint.OAuth2AuthorizationCodeGrantRequest;
import org.springframework.security.oauth2.client.endpoint.OAuth2AuthorizationCodeGrantRequestEntityConverter;
import org.springframework.security.oauth2.client.http.OAuth2ErrorResponseErrorHandler;
import org.springframework.security.oauth2.core.OAuth2AuthenticationException;
import org.springframework.security.oauth2.core.endpoint.OAuth2AccessTokenResponse;
import org.springframework.security.oauth2.core.http.converter.OAuth2AccessTokenResponseHttpMessageConverter;
import org.springframework.util.LinkedMultiValueMap;
import org.springframework.util.MultiValueMap;
import org.springframework.web.client.RestTemplate;

import java.io.IOException;
import java.net.URI;
import java.net.URLEncoder;
import java.util.*;

/**
 * {@link OAuth2AccessTokenResponseClient} used for OAuth2 login with the "qq", "weixin", "dgut" and "wxmp"
 * client registrations.
 * <p>
 * The following two classes are central to how this client works:
 * <ul>
 * <li>{@link OAuth2AccessTokenResponseHttpMessageConverter} (reads the token endpoint response)</li>
 * <li>{@link OAuth2AuthorizationCodeGrantRequestEntityConverter} (builds the token endpoint request)</li>
 * </ul>
 * <p>
 * Thread-safety: all collaborators are created and configured once in the constructor and are never
 * reconfigured per request, so a single instance can be shared by concurrent logins.
 *
 * @author Sai
 * @see OAuth2AuthorizationCodeGrantRequest
 * @see OAuth2AccessTokenResponse
 * @see DefaultAuthorizationCodeTokenResponseClient
 * @see OAuth2AccessTokenResponseClient
 * @see SaiCommonOAuth2AccessTokenResponseHttpMessageConverter
 * @see SaiQQOAuth2AccessTokenResponseHttpMessageConverter
 * @since 1.0.5
 */
public final class SaiCommonOAuth2AccessTokenResponseClient implements OAuth2AccessTokenResponseClient<OAuth2AuthorizationCodeGrantRequest> {

    /** Registration id whose token flow needs the QQ-specific converter and the extra openid lookup. */
    private static final String QQ_REGISTRATION_ID = "qq";

    /** QQ endpoint that resolves an access token to the user's openid. */
    private static final URI QQ_OPENID_URI = URI.create("https://graph.qq.com/oauth2.0/me");

    /** Connect and read timeout for the openid request, in milliseconds. */
    private static final int QQ_OPENID_TIMEOUT_MILLIS = 30000;

    /** Shared JSON reader; ObjectMapper is thread-safe once configured, so one instance suffices. */
    private static final ObjectMapper OBJECT_MAPPER = new ObjectMapper();

    /** Token client for every registration except QQ. */
    private final DefaultAuthorizationCodeTokenResponseClient commonTokenResponseClient;

    /** Token client for the QQ registration. */
    private final DefaultAuthorizationCodeTokenResponseClient qqTokenResponseClient;

    /**
     * Builds one fully configured token client per response format.
     * <p>
     * The common converter additionally accepts {@code text/html;charset=utf-8} and {@code text/plain}
     * responses. The QQ converter is used with its own default media types.
     */
    public SaiCommonOAuth2AccessTokenResponseClient() {
        OAuth2AccessTokenResponseHttpMessageConverter commonConverter = new SaiCommonOAuth2AccessTokenResponseHttpMessageConverter();
        List<MediaType> mediaTypes = new ArrayList<>(commonConverter.getSupportedMediaTypes());
        mediaTypes.add(new MediaType("text", "html", Collections.singletonMap("charset", "utf-8")));
        mediaTypes.add(MediaType.TEXT_PLAIN);
        commonConverter.setSupportedMediaTypes(mediaTypes);

        this.commonTokenResponseClient = createTokenResponseClient(commonConverter);
        this.qqTokenResponseClient = createTokenResponseClient(new SaiQQOAuth2AccessTokenResponseHttpMessageConverter());
    }

    /**
     * Creates a {@link DefaultAuthorizationCodeTokenResponseClient} backed by its own {@link RestTemplate}.
     * <p>
     * The token request body is a form ({@code MultiValueMap}). RestTemplate can only write it when a
     * {@link FormHttpMessageConverter} is registered (see
     * {@code RestTemplate.HttpEntityRequestCallback#doWithRequest(ClientHttpRequest)}), so one is always
     * registered alongside the response converter.
     *
     * @param tokenResponseConverter converter that reads the provider's token response
     * @return a token client that is never reconfigured after this call
     */
    private static DefaultAuthorizationCodeTokenResponseClient createTokenResponseClient(HttpMessageConverter<?> tokenResponseConverter) {
        List<HttpMessageConverter<?>> messageConverters = new ArrayList<>();
        messageConverters.add(new FormHttpMessageConverter());
        messageConverters.add(tokenResponseConverter);

        RestTemplate restTemplate = new RestTemplate(messageConverters);
        restTemplate.setErrorHandler(new OAuth2ErrorResponseErrorHandler());

        DefaultAuthorizationCodeTokenResponseClient client = new DefaultAuthorizationCodeTokenResponseClient();
        client.setRestOperations(restTemplate);
        client.setRequestEntityConverter(new SaiOAuth2AuthorizationCodeGrantRequestEntityConverter());
        return client;
    }

    /**
     * Exchanges the authorization code for an access token.
     * <p>
     * For the "qq" registration, the token endpoint does not return the openid. It is therefore
     * fetched with a second request and added to the response's additional parameters under the
     * key {@code "openid"}.
     *
     * @param authorizationGrantRequest the authorization code grant request
     * @return the access token response
     * @throws OAuth2AuthenticationException as propagated from the underlying token client
     */
    @Override
    public OAuth2AccessTokenResponse getTokenResponse(OAuth2AuthorizationCodeGrantRequest authorizationGrantRequest) throws OAuth2AuthenticationException {
        if (!QQ_REGISTRATION_ID.equals(authorizationGrantRequest.getClientRegistration().getRegistrationId())) {
            return commonTokenResponseClient.getTokenResponse(authorizationGrantRequest);
        }

        OAuth2AccessTokenResponse tokenResponse = qqTokenResponseClient.getTokenResponse(authorizationGrantRequest);
        String qqOpenid = getQQOpenid(tokenResponse.getAccessToken().getTokenValue());
        Map<String, Object> additionalParameters = new LinkedHashMap<>(tokenResponse.getAdditionalParameters());
        additionalParameters.put("openid", qqOpenid);
        return OAuth2AccessTokenResponse.withResponse(tokenResponse).additionalParameters(additionalParameters).build();
    }

    /**
     * Fetches the QQ openid that belongs to an access token.
     * <p>
     * The endpoint answers with a callback-wrapped body such as
     * {@code callback( {"client_id":"YOUR_APPID","openid":"YOUR_OPENID"} );}. The JSON object is cut
     * out of that wrapper before it is parsed.
     *
     * @param token the QQ access token
     * @return the openid
     * @throws AuthenticationServiceException if the request or parsing fails, if the response has no
     *                                        openid, or if the response contains an {@code error} entry
     */
    private String getQQOpenid(String token) {
        try {
            HTTPRequest openidRequest = new HTTPRequest(HTTPRequest.Method.GET, QQ_OPENID_URI.toURL());
            // Encode the token so that reserved characters cannot corrupt the query string.
            openidRequest.setQuery("access_token=" + URLEncoder.encode(token, "UTF-8"));
            openidRequest.setConnectTimeout(QQ_OPENID_TIMEOUT_MILLIS);
            openidRequest.setReadTimeout(QQ_OPENID_TIMEOUT_MILLIS);
            HTTPResponse openidResponse = openidRequest.send();

            // Keep the text between the first '{' and the following '}' and re-wrap it as a JSON object.
            String openidJson = "{" + StringUtils.substringBefore(StringUtils.substringAfter(openidResponse.getContent(), "{"), "}") + "}";
            Map<String, Object> map = OBJECT_MAPPER.readValue(openidJson, new TypeReference<Map<String, Object>>() {
            });

            Object openid = map.get("openid");
            if (openid == null || map.containsKey("error")) {
                // AuthenticationException subtype so Spring Security treats this as a login failure.
                throw new AuthenticationServiceException("Could not get the QQ openid:" + map.get("error"));
            }
            return openid.toString();
        } catch (IOException e) {
            throw new AuthenticationServiceException("An error occurred while getting the Openid Request: " +
                    e.getMessage(), e);
        }
    }

    /**
     * Token request converter that adds provider-specific parameters to the standard
     * {@link OAuth2AuthorizationCodeGrantRequest} form.
     *
     * @author Sai
     * @since 1.0.5
     */
    private static final class SaiOAuth2AuthorizationCodeGrantRequestEntityConverter extends OAuth2AuthorizationCodeGrantRequestEntityConverter {

        /**
         * Builds the standard request and replaces its body with the rebuilt form parameters.
         * Headers, method and URL are kept.
         */
        @SuppressWarnings("ConstantConditions")
        @Override
        public RequestEntity<?> convert(OAuth2AuthorizationCodeGrantRequest authorizationCodeGrantRequest) {
            RequestEntity<?> request = Objects.requireNonNull(super.convert(authorizationCodeGrantRequest));
            return new RequestEntity<>(rebuildFormParameters(request, authorizationCodeGrantRequest), request.getHeaders(), request.getMethod(), request.getUrl());
        }

        /**
         * Rebuilds the request parameters to meet the QQ, weixin/wxmp and dgut token endpoint requirements.
         * <p>
         * All parameters produced by the superclass are kept, and registration-specific ones are appended:
         * <ul>
         * <li>dgut: {@code token} (the authorization code), {@code appid}, {@code appsecret} and
         * {@code userip} (taken from the authorization request's additional parameter {@code "ip"})</li>
         * <li>weixin / wxmp: {@code appid} and {@code secret}</li>
         * <li>qq and any other registration: nothing extra</li>
         * </ul>
         *
         * @param request                       the request built by the superclass
         * @param authorizationCodeGrantRequest the grant request
         * @return the new request parameters
         */
        @SuppressWarnings("unchecked")
        private MultiValueMap<String, String> rebuildFormParameters(RequestEntity<?> request, OAuth2AuthorizationCodeGrantRequest authorizationCodeGrantRequest) {
            MultiValueMap<String, String> formParameters = new LinkedMultiValueMap<>();
            // The superclass builds a form body of String values, so the unchecked cast is safe.
            formParameters.addAll((MultiValueMap<String, String>) Objects.requireNonNull(Objects.requireNonNull(request).getBody()));

            switch (authorizationCodeGrantRequest.getClientRegistration().getRegistrationId()) {
                case "dgut":
                    formParameters.add("token", authorizationCodeGrantRequest.getAuthorizationExchange().getAuthorizationResponse().getCode());
                    formParameters.add("appid", authorizationCodeGrantRequest.getClientRegistration().getClientId());
                    formParameters.add("appsecret", authorizationCodeGrantRequest.getClientRegistration().getClientSecret());
                    formParameters.add("userip", (String) authorizationCodeGrantRequest.getAuthorizationExchange().getAuthorizationRequest().getAdditionalParameters().get("ip"));
                    break;
                case "wxmp":
                case "weixin":
                    // WeChat's token endpoint expects "appid" and "secret" instead of the OAuth2-standard
                    // client_id; the standard parameters are left in place and these are appended.
                    formParameters.add("appid", authorizationCodeGrantRequest.getClientRegistration().getClientId());
                    formParameters.add("secret", authorizationCodeGrantRequest.getClientRegistration().getClientSecret());
                    break;
                case "qq":
                default:
                    // No extra parameters.
                    break;
            }

            return formParameters;
        }
    }

}
